import numpy as np
import random

from d4rl.pointmaze import q_iteration
from d4rl.pointmaze.gridcraft import grid_env
from d4rl.pointmaze.gridcraft import grid_spec

# Text layout of the maze handed to grid_spec.spec_from_string: 19 rows, each
# followed by a literal backslash. The grid is split into four rooms by a
# vertical and a horizontal divider; each divider segment has a single
# one-cell opening.
# NOTE(review): presumably '#' is a wall, 'O' an open cell and the backslash
# the row separator -- confirm against grid_spec's string map.
MAZE = "".join(
    row + "\\"
    for row in (
        "###################",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOOOOOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "####O#########O####",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "#OOOOOOOOOOOOOOOOO#",
        "#OOOOOOOO#OOOOOOOO#",
        "###################",
    )
)


# Maps the index of the best Q-value action (as used in
# FourRoomController.get_action) to an orientation index in R, D, L, U
# order (0..3), which is the indexing used by TURN_DIRS. Action 0 maps to
# None, meaning "no direction change requested".
# NOTE(review): this table was previously labelled "NLUDR -> RDLU", but with
# R, D, L, U outputs the values below correspond to an N, U, D, L, R action
# order (1 -> U, 2 -> D, 3 -> L, 4 -> R). Confirm against grid_env's action
# constants.
TRANSLATE_DIRECTION = {
    0: None,
    1: 3,
    2: 1,
    3: 2,
    4: 0,
}

# Command codes returned by FourRoomController.get_action: LEFT / RIGHT come
# from the TURN_DIRS lookup, FORWARD is returned when no turn is needed.
# The consumer of these codes lives outside this file.
LEFT = 0
RIGHT = 1
FORWARD = 2

class FourRoomController(object):
    """Scripted controller that steers an agent toward a target cell of MAZE.

    A grid environment is built from ``MAZE``; ``set_target`` solves it with
    Q-iteration for a chosen target cell, and ``get_action`` converts the
    greedy Q-value action at the agent's position into a FORWARD / LEFT /
    RIGHT command given the agent's current orientation.

    Attributes:
        env: ``grid_env.GridEnv`` built from ``MAZE``.
        reset_locations: list of index tuples of every cell whose spec value
            equals ``grid_spec.EMPTY``; candidates for ``sample_target``.
        target: cell passed to the last ``set_target`` call (unset before).
        q_values: result of the last Q-iteration run (unset before
            ``set_target`` is called).
    """

    def __init__(self):
        self.env = grid_env.GridEnv(grid_spec.spec_from_string(MAZE))
        # np.where yields one index array per axis; zip them into per-cell tuples.
        self.reset_locations = list(zip(*np.where(self.env.gs.spec == grid_spec.EMPTY)))

    def sample_target(self):
        """Return a randomly chosen entry of ``reset_locations``."""
        return random.choice(self.reset_locations)

    def set_target(self, target):
        """Select ``target`` as the goal and recompute the Q-values for it.

        The target cell is temporarily marked as ``grid_spec.REWARD`` while
        Q-iteration runs (32 iterations, discount 0.99) and is then set to
        ``grid_spec.EMPTY``. The cell is always left as EMPTY, not restored to
        its previous content, so ``target`` is expected to be an empty cell
        such as one returned by ``sample_target``.

        Args:
            target: grid cell key accepted by ``self.env.gs[...]``.
        """
        self.target = target
        self.env.gs[target] = grid_spec.REWARD
        try:
            self.q_values = q_iteration.q_iteration(env=self.env, num_itrs=32, discount=0.99)
        finally:
            # Clear the reward marker even if Q-iteration fails, so a later
            # set_target call does not see a stale second reward cell.
            self.env.gs[target] = grid_spec.EMPTY

    def get_action(self, pos, orientation):
        """Return the next command for an agent at ``pos`` facing ``orientation``.

        Args:
            pos: agent position, passed to ``self.env.gs.xy_to_idx``.
                NOTE(review): ``done`` is computed by comparing ``pos`` with
                ``target`` element-wise, so both must use the same coordinate
                convention -- confirm against the caller.
            orientation: current heading; compared against and used as an
                index alongside the R, D, L, U orientation codes (0..3)
                produced by ``TRANSLATE_DIRECTION``.

        Returns:
            A ``(command, done)`` tuple. ``command`` is ``FORWARD`` when the
            greedy action requests no direction or the agent already faces the
            requested direction, otherwise the turn given by ``get_turn``.
            ``done`` is True when ``pos`` equals the current target.

        Raises:
            AttributeError: if ``set_target`` has not been called yet.
        """
        done = tuple(pos) == tuple(self.target)
        env_pos_idx = self.env.gs.xy_to_idx(pos)
        qvalues = self.q_values[env_pos_idx]
        # Greedy action at this state, translated to an orientation code.
        direction = TRANSLATE_DIRECTION[np.argmax(qvalues)]
        if direction is None or orientation == direction:
            return FORWARD, done
        return get_turn(orientation, direction), done

# Turn lookup table, indexed as TURN_DIRS[current][target], where both
# orientations are codes in R, D, L, U order (0..3).
#   * None on the diagonal: already facing the target, no turn needed.
#   * RIGHT moves one step along R -> D -> L -> U -> R, LEFT one step the
#     other way.
#   * Opposite headings (two steps apart) are resolved with RIGHT.
TURN_DIRS = [
    # target: R, D, L, U
    [None, RIGHT, RIGHT, LEFT],   # currently facing R
    [LEFT, None, RIGHT, RIGHT],   # currently facing D
    [RIGHT, LEFT, None, RIGHT],   # currently facing L
    [RIGHT, RIGHT, LEFT, None],   # currently facing U
]

def get_turn(ori, tgt_ori):
    """Look up the turn command that rotates heading ``ori`` toward ``tgt_ori``.

    Args:
        ori: current orientation code, used as the row index of TURN_DIRS.
        tgt_ori: desired orientation code, used as the column index.

    Returns:
        The TURN_DIRS entry for the pair: LEFT or RIGHT, or None when both
        orientations are the same.
    """
    turns_from_current = TURN_DIRS[ori]
    return turns_from_current[tgt_ori]
